{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 1: Generate hard-links for consecutive ST slices\n",
    "\n",
    "This notebook trains the model with the masked reconstruction loss on two consecutive sections, runs a cluster-wise optimal-transport alignment with PASTE to build the hard-link matrix, and saves the hard-links together with the aligned spot coordinates."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import sys\n",
    "\n",
    "# Put the parent directory of the current working directory first on the\n",
    "# module search path so that the project modules imported below resolve.\n",
    "# This is done before the third-party imports to keep module resolution\n",
    "# identical to importing them after the path change.\n",
    "parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))\n",
    "sys.path.insert(0, parent_dir)\n",
    "\n",
    "import anndata\n",
    "import dgl\n",
    "import numpy as np\n",
    "import ot  # provides ot.backend.TorchBackend used by the OT alignment cell\n",
    "import pandas as pd\n",
    "import paste\n",
    "import scanpy as sc\n",
    "import scipy.sparse  # import the submodule explicitly; `import scipy` alone may not load it\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "from local_alignment_main import local_alignment_loader, run_local_alignment\n",
    "from models import build_model_ST\n",
    "from utils_local_alignment import build_args_ST, create_optimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Hyperparameter setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from the project defaults and override the values used in this tutorial.\n",
    "args = build_args_ST()\n",
    "args.max_epoch = 2000\n",
    "args.max_epoch_triplet = 500\n",
    "args.dataset_name = \"DLPFC\"\n",
    "args.section_ids = [\"151507\", \"151508\"]\n",
    "args.num_class = 7\n",
    "args.num_hidden = \"512,32\"\n",
    "args.alpha_l = 1\n",
    "args.lam = 1\n",
    "args.loss_fn = \"sce\"\n",
    "args.mask_rate = 0.5\n",
    "args.in_drop = 0\n",
    "args.attn_drop = 0\n",
    "args.remask_rate = 0.1\n",
    "args.seeds = [2024]\n",
    "args.num_remasking = 1\n",
    "args.hvgs = 5000\n",
    "args.dataset = \"DLPFC\"\n",
    "args.consecutive_prior = 1\n",
    "args.lr = 0.001\n",
    "args.scheduler = True\n",
    "# Location of the input data; adjust to your local copy.\n",
    "args.st_data_dir = \"../../spatial_benchmarking/benchmarking_data/DLPFC12\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "name: DLPFC\n",
      "------Calculating spatial graph...\n",
      "The graph contains 24762 edges, 4221 cells.\n",
      "5.8664 neighbors per cell on average.\n",
      "------Calculating spatial graph...\n",
      "The graph contains 25692 edges, 4381 cells.\n",
      "5.8644 neighbors per cell on average.\n"
     ]
    }
   ],
   "source": [
    "section_ids = args.section_ids\n",
    "exp_fig_dir = args.exp_fig_dir\n",
    "st_data_dir = args.st_data_dir\n",
    "dataset_name = args.dataset\n",
    "\n",
    "graph, num_features, ad_concat = local_alignment_loader(\n",
    "    section_ids=section_ids,\n",
    "    hvgs=args.hvgs,\n",
    "    st_data_dir=st_data_dir,\n",
    "    dataname=dataset_name,\n",
    ")\n",
    "# The model builder reads the input feature dimension from `args`.\n",
    "args.num_features = num_features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== Use sce_loss and alpha_l=1 ===\n",
      "num_encoder_params: 1170016, num_decoder_params: 1176670, num_params_in_total: 2384362\n"
     ]
    }
   ],
   "source": [
    "model = build_model_ST(args)\n",
    "\n",
    "# A negative device index selects the CPU; any other value is passed to `model.to` unchanged.\n",
    "device = args.device if args.device >= 0 else \"cpu\"\n",
    "model.to(device)\n",
    "\n",
    "optimizer = create_optimizer(args.optimizer, model, args.lr, args.weight_decay)\n",
    "\n",
    "if args.scheduler:\n",
    "    def cosine_decay(epoch):\n",
    "        \"\"\"Learning-rate multiplier following half a cosine: 1 at epoch 0, 0 at epoch `args.max_epoch`.\"\"\"\n",
    "        return (1 + np.cos(epoch * np.pi / args.max_epoch)) * 0.5\n",
    "\n",
    "    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine_decay)\n",
    "else:\n",
    "    scheduler = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Masked reconstruction loss training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_epoch = args.max_epoch\n",
    "max_epoch_triplet = args.max_epoch_triplet\n",
    "num_class = args.num_class\n",
    "alpha_value = args.alpha_value\n",
    "\n",
    "# Training. Returns one AnnData per section (`batchlist_`) and the updated concatenated AnnData.\n",
    "batchlist_, ad_concat = run_local_alignment(\n",
    "    graph,\n",
    "    model,\n",
    "    device,\n",
    "    ad_concat,\n",
    "    section_ids,\n",
    "    max_epoch=max_epoch,\n",
    "    max_epoch_triplet=max_epoch_triplet,\n",
    "    optimizer=optimizer,\n",
    "    scheduler=scheduler,\n",
    "    logger=None,\n",
    "    num_class=num_class,\n",
    "    use_mnn=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### OT alignment to generate hard-links\n",
    "\n",
    "Spots are aligned cluster by cluster: for every `mclust` label present in both slices, PASTE computes a mapping between the spots of that cluster, and the result is written into the corresponding sub-block of the global mapping matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "slice1 = batchlist_[0]\n",
    "slice2 = batchlist_[1]\n",
    "\n",
    "# Rows index the spots of slice1, columns the spots of slice2 (positional order of `.obs`).\n",
    "global_PI = np.zeros((slice1.n_obs, slice2.n_obs))\n",
    "\n",
    "use_gpu = torch.cuda.is_available()\n",
    "\n",
    "for i in range(num_class):\n",
    "    print(\"run for cluster:\", i)\n",
    "    # Cluster labels in `mclust` are compared against 1..num_class.\n",
    "    mask1 = (slice1.obs['mclust'] == i + 1).to_numpy()\n",
    "    mask2 = (slice2.obs['mclust'] == i + 1).to_numpy()\n",
    "    rows = np.flatnonzero(mask1)\n",
    "    cols = np.flatnonzero(mask2)\n",
    "    if rows.size == 0 or cols.size == 0:\n",
    "        # The cluster is missing from at least one slice: nothing to align.\n",
    "        continue\n",
    "\n",
    "    subslice1 = slice1[mask1]\n",
    "    subslice2 = slice2[mask2]\n",
    "\n",
    "    # If there is only one spot in a slice, spatial dissimilarity can't be normalized,\n",
    "    # so both the normalization and the spatial-heuristic initialisation are skipped.\n",
    "    can_normalize = rows.size > 1 and cols.size > 1\n",
    "    if can_normalize:\n",
    "        G_init = paste.match_spots_using_spatial_heuristic(\n",
    "            subslice1.obsm['spatial'], subslice2.obsm['spatial'], use_ot=True\n",
    "        )\n",
    "    else:\n",
    "        G_init = None\n",
    "\n",
    "    local_PI = paste.pairwise_align(\n",
    "        subslice1,\n",
    "        subslice2,\n",
    "        alpha=alpha_value,\n",
    "        dissimilarity='kl',\n",
    "        use_rep=None,\n",
    "        norm=can_normalize,\n",
    "        verbose=True,\n",
    "        G_init=G_init,\n",
    "        use_gpu=use_gpu,\n",
    "        backend=ot.backend.TorchBackend(),\n",
    "    )\n",
    "    if torch.is_tensor(local_PI):\n",
    "        local_PI = local_PI.detach().cpu().numpy()\n",
    "\n",
    "    # Subsetting with a boolean mask keeps the original spot order, so the local\n",
    "    # mapping can be scattered into the global matrix by position in one step.\n",
    "    global_PI[np.ix_(rows, cols)] = np.asarray(local_PI)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save hard-links and aligned coordinates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if exp_fig_dir:\n",
    "    os.makedirs(exp_fig_dir, exist_ok=True)\n",
    "\n",
    "# Hard-links are stored as a sparse matrix, unnormalised, exactly as computed above.\n",
    "file_name = f\"{section_ids[0]}_{section_ids[1]}_{alpha_value}\"\n",
    "mapping_mat = scipy.sparse.csr_matrix(global_PI)\n",
    "hard_link_path = os.path.join(exp_fig_dir, file_name + \"_HL.pickle\")\n",
    "with open(hard_link_path, 'wb') as handle:\n",
    "    pickle.dump(mapping_mat, handle)\n",
    "\n",
    "# `stack_slices_pairwise` takes one dense mapping per consecutive pair of slices.\n",
    "# NOTE(review): PASTE's Procrustes step appears to centre the coordinates with\n",
    "# pi-weighted sums, which presumes a total mass of 1, while `global_PI` sums the\n",
    "# per-cluster mappings. It is therefore normalised for stacking only; confirm\n",
    "# against the installed PASTE version.\n",
    "total_mass = global_PI.sum()\n",
    "if total_mass <= 0:\n",
    "    raise ValueError('global_PI is empty: no cluster is present in both slices, so the slices cannot be stacked.')\n",
    "new_slices = paste.stack_slices_pairwise([slice1, slice2], [global_PI / total_mass])\n",
    "\n",
    "for i, aligned_slice in enumerate(new_slices):\n",
    "    spatial_data = aligned_slice.obsm['spatial']\n",
    "\n",
    "    output_path = os.path.join(exp_fig_dir, f\"coordinates_{section_ids[i]}.csv\")\n",
    "    pd.DataFrame(spatial_data).to_csv(output_path, index=False)\n",
    "    print(f\"Saved spatial data for slice {i} to {output_path}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MaskGraphene",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}